# core/formalization/rl/rl_agent.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
from typing import Dict, List, Tuple, Any, Optional

from core.formalization.action_space import ActionType
from core.formalization.rl.state import State
from core.formalization.rl.reward import Reward
from core.formalization.rl.action_mask import ActionMask
from core.formalization.rl.exp import StepExp, ExpBuffer
from utils.logger import Logger
from llm.llm_wrapper import LLMWrapper
from llm.auxiliary import Auxiliary


class SharedFeatureExtractor(nn.Module):
    """MLP trunk that turns a state vector into a feature vector.

    Each hidden layer is ``Linear -> ReLU -> Dropout(0.1)``. The width of the
    produced features is exposed as ``output_dim`` so that heads stacked on
    top can size their input layer.
    """

    def __init__(self, state_dim: int, hidden_dims: Optional[List[int]] = None):
        """Build the trunk.

        Args:
            state_dim: Size of the last dimension of the input state tensor.
            hidden_dims: Width of each hidden layer, in order. ``None`` selects
                the default ``[256, 128]``. An empty sequence yields an
                identity extractor whose ``output_dim`` equals ``state_dim``.
        """
        super().__init__()

        # Resolve the default here rather than in the signature so that no
        # list object is shared between instances.
        hidden_dims = [256, 128] if hidden_dims is None else list(hidden_dims)

        layers = []
        input_dim = state_dim

        for hidden_dim in hidden_dims:
            layers.extend(
                [nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.1)]
            )
            input_dim = hidden_dim

        self.network = nn.Sequential(*layers)
        # After the loop ``input_dim`` is the width of the last layer, or
        # ``state_dim`` when there are no hidden layers at all.
        self.output_dim = input_dim
        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer and zero its bias."""
        # NOTE(review): gain=0.01 on *hidden* layers produces very small
        # activations; kept as-is because changing it alters training.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight, gain=0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return the features computed from ``state``."""
        return self.network(state)


class PolicyHead(nn.Module):
    """Linear layer producing one logit per action, with optional masking."""

    def __init__(self, input_dim: int, action_dim: int):
        """Create the head.

        Args:
            input_dim: Width of the incoming feature vector.
            action_dim: Number of logits to produce.
        """
        super().__init__()
        self.action_head = nn.Linear(input_dim, action_dim)
        self._init_weights()

    def _init_weights(self):
        """Orthogonal init with a small gain and a zero bias."""
        nn.init.orthogonal_(self.action_head.weight, gain=0.01)
        nn.init.constant_(self.action_head.bias, 0)

    def forward(
        self, features: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Compute action logits.

        Args:
            features: Feature tensor whose last dimension is ``input_dim``.
            mask: Optional tensor broadcastable to the logits. Entries equal
                to 1 leave the logit untouched; entries equal to 0 add
                ``-1e9`` to it. Float, integer and bool masks are accepted.

        Returns:
            The (possibly masked) logits.
        """
        logits = self.action_head(features)

        if mask is not None:
            # Bool tensors do not support ``mask - 1``; casting to the logits'
            # dtype (and device) makes every mask flavour behave the same.
            mask = mask.to(dtype=logits.dtype, device=logits.device)
            logits = logits + (mask - 1) * 1e9

        return logits


class ValueHead(nn.Module):
    """Critic head combining state features with a reward-component vector.

    Features and reward components are embedded separately (64 and 16 units),
    concatenated, and reduced to a single scalar per sample.
    """

    def __init__(self, feature_dim: int, reward_dim: int = 4):
        """Create the head.

        Args:
            feature_dim: Width of the incoming feature vector.
            reward_dim: Width of the reward-component vector. Defaults to 4.
        """
        super().__init__()

        self.reward_dim = reward_dim

        self.feature_layer = nn.Sequential(nn.Linear(feature_dim, 64), nn.ReLU())

        self.reward_layer = nn.Sequential(nn.Linear(reward_dim, 16), nn.ReLU())

        self.combined_layer = nn.Sequential(
            nn.Linear(64 + 16, 32), nn.ReLU(), nn.Linear(32, 1)
        )

        self._init_weights()

    def _init_weights(self):
        """Orthogonally initialise every linear layer and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.orthogonal_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(
        self, features: torch.Tensor, reward_components: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Estimate the value of each sample.

        Args:
            features: Tensor whose last dimension is ``feature_dim``.
            reward_components: Tensor whose last dimension is ``reward_dim``
                and whose leading dimensions match ``features``. When ``None``
                a zero vector is used instead.

        Returns:
            Tensor with the leading dimensions of ``features`` and a trailing
            dimension of size 1.
        """
        if reward_components is None:
            # The signature advertises ``None`` as valid, so substitute a
            # neutral input instead of failing inside the linear layer.
            reward_components = features.new_zeros(
                tuple(features.shape[:-1]) + (self.reward_dim,)
            )

        feature_embedding = self.feature_layer(features)
        reward_embedding = self.reward_layer(reward_components)
        # Concatenate along the last axis so unbatched inputs work as well.
        combined = torch.cat([feature_embedding, reward_embedding], dim=-1)
        return self.combined_layer(combined)


class ACNet(nn.Module):
    """Actor-critic network: one shared trunk feeding a policy and a value head."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_dims: List[int] = [256, 128]
    ):
        """Assemble the trunk and both heads.

        Args:
            state_dim: Width of the state vector fed to the trunk.
            action_dim: Number of discrete actions scored by the policy head.
            hidden_dims: Hidden layer widths forwarded to the trunk.
        """
        super().__init__()

        self.feature_extractor = SharedFeatureExtractor(state_dim, hidden_dims)
        trunk_width = self.feature_extractor.output_dim

        self.policy_head = PolicyHead(trunk_width, action_dim)
        self.value_head = ValueHead(trunk_width)

    def forward_actor(
        self,
        state: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ):
        """Return the policy logits for ``state``, masked when ``mask`` is given."""
        return self.policy_head(self.feature_extractor(state), mask)

    def forward_critic(
        self,
        state: torch.Tensor,
        reward_components: torch.Tensor,
    ):
        """Return the value head's output for ``state`` and ``reward_components``."""
        return self.value_head(self.feature_extractor(state), reward_components)

    def forward(
        self,
        state: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        reward_components: Optional[torch.Tensor] = None,
    ):
        """Run the trunk once and return ``(logits, value)`` from both heads."""
        shared = self.feature_extractor(state)
        policy_logits = self.policy_head(shared, mask)
        state_value = self.value_head(shared, reward_components)
        return policy_logits, state_value

    def get_action_and_log_prob(
        self, state: torch.Tensor, mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample an action from the categorical policy.

        Returns:
            ``(action, log_prob)`` where ``log_prob`` is the log-probability
            of the sampled action under the same distribution.
        """
        policy = Categorical(logits=self.forward_actor(state, mask))
        sampled = policy.sample()
        return sampled, policy.log_prob(sampled)

    def get_log_prob(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return the log-probability of ``action`` under the current policy."""
        policy = Categorical(logits=self.forward_actor(state, mask))
        return policy.log_prob(action)

    def evaluate(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        reward_components: Optional[torch.Tensor] = None,
    ):
        """Score given actions with both heads.

        Returns:
            ``(log_prob, value, entropy)``: the log-probability of ``action``,
            the value head's output, and the entropy of the policy
            distribution.
        """
        policy_logits, state_value = self.forward(state, mask, reward_components)
        policy = Categorical(logits=policy_logits)
        return policy.log_prob(action), state_value, policy.entropy()
    
class PPO:
    """Clipped-surrogate policy-gradient agent built around an ``ACNet``.

    The agent owns the actor-critic network, its Adam optimizer, an
    ``ExpBuffer`` of collected steps, and the project helpers that compute
    state vectors, rewards and action masks.
    """

    # Column order of ``batch.reward_components`` as indexed by ``update``,
    # and the keys copied by ``store_experience``.
    _REWARD_KEYS = ('rs', 're', 'rh', 'rd')

    def __init__(
        self,
        logger: "Logger",
        llm: "LLMWrapper",
        auxiliary: "Auxiliary",
        state_dim: int,
        action_dim: int,
        actions: List,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Create the network, optimizer, buffer and helper calculators.

        Args:
            logger: Logger used for progress messages.
            llm: LLM wrapper forwarded to the state/reward/mask helpers.
            auxiliary: Auxiliary helper forwarded to the state/reward helpers.
            state_dim: Width of the state vector.
            action_dim: Number of discrete actions.
            actions: Action list forwarded to the state and mask helpers.
            config: Hyper-parameter overrides. ``None`` means "all defaults".
        """
        self.logger = logger
        self.llm = llm
        self.auxiliary = auxiliary

        self.state_dim = state_dim
        self.action_dim = action_dim
        # A fresh dict per instance: a ``{}`` default in the signature would
        # be shared by every agent and by the helpers it is passed to.
        self.config = config if config is not None else {}
        self.actions = actions

        self.lr = self.config.get('lr', 3e-4)
        # NOTE(review): the original only consulted the key '9', which looks
        # like a typo. 'discount_factor' is an assumed name - confirm against
        # the project's config files. '9' is still honoured as a fallback.
        self.discount_factor = self.config.get(
            'discount_factor', self.config.get('9', 0.99)
        )
        self.smooth_factor = self.config.get('gae_smooth_factor', 0.95)
        self.eps_clip = self.config.get('eps_clip', 0.2)
        self.entropy_coef = self.config.get('entropy_coef', 0.01)
        self.value_coef = self.config.get('value_coef', 0.5)
        self.max_grad_norm = self.config.get('max_grad_norm', 0.5)
        self.exp_reward_weights = self.config.get("exp_reward_weights", {
            "rs": 1.0,
            "re": 0.2,
            "rh": 0.5,
            "rd": 1.0
        })
        self.exp_factor = 1
        self.total_update_count = 0

        hidden_dims = [256, 128]

        self.logger.info("RL agent start to init...")
        self.ac_network = ACNet(state_dim, action_dim, hidden_dims)
        self.optimizer = optim.Adam(self.ac_network.parameters(), lr=self.lr)

        self.buffer = ExpBuffer(self.config.get('buffer_capacity', 1000))

        self.state_calculator = State(logger, llm, auxiliary, actions, self.config)
        self.reward_calculator = Reward(logger, llm, auxiliary, self.config)

        self.action_mask = ActionMask(logger, llm, actions)

    def update_exp_factor(self, exp_factor):
        """Set the mixing weight between weighted rewards and buffer returns."""
        self.exp_factor = exp_factor

    def compute_exp_reward(self, reward_components):
        """Return the weighted sum of ``reward_components``.

        Components without an entry in ``exp_reward_weights`` get weight 0.
        """
        return sum(self.exp_reward_weights.get(k, 0) * v for k, v in reward_components.items())

    def compute_action_mask(self, current_text: str, context=None):
        """Delegate to the ``ActionMask`` helper."""
        return self.action_mask.compute_action_mask(current_text, context)

    def select_action(self, state: np.ndarray, mask: np.ndarray, training: bool = True):
        """Choose an action for ``state`` among those allowed by ``mask``.

        Args:
            state: 1-D state vector.
            mask: 1-D vector with 1 for allowed actions and 0 otherwise.
            training: When ``True`` the action is sampled from the policy, or
                with probability ``get_exploration_rate()`` drawn uniformly
                from the allowed actions. When ``False`` the highest-logit
                action is returned and the network is evaluated in eval mode.

        Returns:
            ``(action, log_prob)`` as a Python ``int`` and ``float``, where
            ``log_prob`` is the policy's log-probability of ``action``.
        """
        # Greedy selection must not be perturbed by dropout; switch to eval
        # mode for the call and restore the previous mode afterwards.
        was_training = self.ac_network.training
        if not training:
            self.ac_network.eval()

        try:
            with torch.no_grad():
                state_tensor = torch.FloatTensor(state).unsqueeze(0)
                mask_tensor = torch.FloatTensor(mask).unsqueeze(0)

                # One forward pass serves both the action choice and its
                # log-probability, so the two always agree.
                logits = self.ac_network.forward_actor(state_tensor, mask_tensor)
                dist = Categorical(logits=logits)

                if training:
                    action = None
                    if np.random.random() < self.get_exploration_rate():
                        valid_actions = np.where(np.asarray(mask) == 1)[0]
                        if len(valid_actions) > 0:
                            action = torch.tensor([int(np.random.choice(valid_actions))])
                    if action is None:
                        action = dist.sample()
                else:
                    action = torch.argmax(logits, dim=-1)

                log_prob = dist.log_prob(action)
                return int(action.item()), float(log_prob.item())
        finally:
            if was_training and not training:
                self.ac_network.train()

    def get_exploration_rate(self):
        """Return the exploration probability for the current update count.

        The rate decays geometrically from ``init_exploration`` by
        ``exploration_decay`` per call to ``update`` and is floored at
        ``min_exploration`` (all read from the config).
        """
        init_exploration = self.config.get('init_exploration', 0.5)
        min_exploration = self.config.get('min_exploration', 0.05)
        exploration_decay = self.config.get('exploration_decay', 0.995)

        return max(
            min_exploration,
            init_exploration * (exploration_decay ** self.total_update_count)
        )

    def compute_state(self, original_query: str, cur_query: str, interaction_history: Dict):
        """Delegate to the ``State`` helper's ``compute_state_vector``."""
        return self.state_calculator.compute_state_vector(original_query, cur_query, interaction_history)

    def compute_reward(self, cur_query: str, target: str, cur_response: str, original_query: str, last_response: str, step_count: int):
        """Delegate to the ``Reward`` helper's ``compute_reward``."""
        return self.reward_calculator.compute_reward(cur_query, target, cur_response, original_query, last_response, step_count)

    def store_experience(self,
                        state: np.ndarray,
                        action: int,
                        log_prob: float,
                        value: float,
                        reward_info: Dict,
                        total_reward: float,
                        done: bool,
                        mask: np.ndarray):
        """Wrap one environment step in a ``StepExp`` and add it to the buffer.

        Reward components are copied from ``reward_info['reward']`` (missing
        keys default to 0.0); when ``reward_info`` has no ``'reward'`` entry
        an empty dict is stored.
        """
        reward_components = {}
        if 'reward' in reward_info:
            raw_components = reward_info['reward']
            reward_components = {
                key: raw_components.get(key, 0.0) for key in self._REWARD_KEYS
            }

        experience = StepExp(
            state=state,
            action=action,
            log_prob=log_prob,
            value=value,
            exp_reward=total_reward,
            done=done,
            mask=mask,
            reward_components=reward_components,
        )

        self.buffer.add_exp(experience)

    def update(self, batch_size: int = 4, epochs: int = 8) -> Dict[str, float]:
        """Run up to ``epochs`` optimisation steps on batches from the buffer.

        Args:
            batch_size: Size requested from ``buffer.get_training_batch``.
            epochs: Number of batches to request. Requests that return
                ``None`` are skipped.

        Returns:
            Averages of ``loss``, ``actor_loss``, ``critic_loss`` and
            ``entropy`` over the steps actually performed (zeros if none),
            plus ``epochs`` and ``updates`` (number of performed steps).
            An empty dict when ``epochs`` is not positive.
        """
        self.logger.info("RL agent start to update network")
        totals = {'loss': 0.0, 'actor_loss': 0.0, 'critic_loss': 0.0, 'entropy': 0.0}
        performed = 0

        for _ in range(epochs):
            batch = self.buffer.get_training_batch(batch_size)

            if batch is None:
                continue

            advantages = batch.advantages
            # std() of a single element is NaN and would poison every
            # parameter, so only normalise when there is something to
            # normalise against.
            if advantages.numel() > 1:
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            log_probs, values, entropy = self.ac_network.evaluate(
                batch.states, batch.actions, batch.masks, batch.reward_components
            )
            mean_entropy = entropy.mean()

            ratio = torch.exp(log_probs - batch.old_log_probs)
            clipped_ratio = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip)
            actor_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()

            # One bulk tolist() instead of four .item() calls per row.
            exp_rewards = [
                self.compute_exp_reward(dict(zip(self._REWARD_KEYS, row)))
                for row in batch.reward_components[:, :4].tolist()
            ]
            exp_rewards_tensor = torch.FloatTensor(exp_rewards).to(batch.returns.device)
            mixed_returns = self.exp_factor * exp_rewards_tensor + (1 - self.exp_factor) * batch.returns

            # squeeze(-1) keeps the batch axis even for a batch of one.
            value_loss = F.mse_loss(values.squeeze(-1), mixed_returns)

            loss = (
                actor_loss +
                self.value_coef * value_loss -
                self.entropy_coef * mean_entropy
            )

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.ac_network.parameters(), self.max_grad_norm)
            self.optimizer.step()

            totals['loss'] += loss.item()
            totals['actor_loss'] += actor_loss.item()
            totals['critic_loss'] += value_loss.item()
            totals['entropy'] += mean_entropy.item()
            performed += 1

        # Counted per call (not per optimisation step); this drives the
        # exploration decay in get_exploration_rate().
        self.total_update_count += 1

        if epochs <= 0:
            return {}

        # Average over the steps that really ran so skipped batches do not
        # dilute the statistics.
        divisor = max(performed, 1)
        avg_loss = totals['loss'] / divisor
        avg_actor_loss = totals['actor_loss'] / divisor
        avg_critic_loss = totals['critic_loss'] / divisor
        avg_entropy = totals['entropy'] / divisor

        self.logger.info(f"Update stats - Loss: {avg_loss:.4f}, Actor: {avg_actor_loss:.4f}, "
                        f"Critic: {avg_critic_loss:.4f}, Entropy: {avg_entropy:.4f}")

        return {
            'loss': avg_loss,
            'actor_loss': avg_actor_loss,
            'critic_loss': avg_critic_loss,
            'entropy': avg_entropy,
            'epochs': epochs,
            'updates': performed,
        }

    def save_model(self, filepath: str):
        """Write network weights, optimizer state and config to ``filepath``."""
        torch.save({
            'ac_network_state_dict': self.ac_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }, filepath)

        self.logger.info(f"Model saved to {filepath}")

    def load_model(self, filepath: str):
        """Restore network weights and optimizer state from ``filepath``.

        The stored config is not applied to this instance.
        """
        self.logger.info(f"Start to load Model from {filepath}")
        # SECURITY: torch.load unpickles the file; only load checkpoints from
        # trusted sources. map_location lets a checkpoint saved on another
        # device load here; load_state_dict copies into the existing tensors.
        checkpoint = torch.load(filepath, map_location="cpu")

        self.ac_network.load_state_dict(checkpoint['ac_network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.logger.info(f"Model loaded from {filepath}")

    def clear_buffer(self):
        """Drop every experience held by the buffer."""
        self.buffer.clear()
